# coding=utf-8
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
# Copyright (c) 2018, NVIDIA CORPORATION.  All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Finetuning the library models for multiple choice (Bert, Roberta, XLNet)."""
import sys
# def my_tracer(frame, event, arg = None):
#     # extracts frame code
#     code = frame.f_code
  
#     # extracts calling function name
#     func_name = code.co_name
  
#     # extracts the line number
#     line_no = frame.f_lineno
  
#     print(f"A {event} encountered in \
#     {func_name}() at line number {line_no} ")
  
#     return my_tracer
# sys.settrace(my_tracer)
import logging
import os
os.environ["TOKENIZERS_PARALLELISM"] = "false"
os.environ["WANDB_DISABLED"] = "true"
from dataclasses import dataclass, field
from typing import Dict, Optional
from torch import nn as nn
import copy
import numpy as np
import torch

import transformers
from transformers import (
    AutoConfig,
    AutoModel,
    AutoTokenizer,
    EvalPrediction,
    HfArgumentParser,
    Trainer,
    TrainingArguments,
    set_seed,
    BertPreTrainedModel
)
from transformers.modeling_outputs import MultipleChoiceModelOutput
from transformers.trainer_utils import is_main_process
from utils_multiple_choice import MultipleChoiceDataset, Split, processors
from models.clip import load_clip
from models.utils_model import LXRTXLayer, BertPooler, BertPreTrainingHeads
logger = logging.getLogger(__name__)


def simple_accuracy(preds, labels):
    """Return the fraction of positions where ``preds`` equals ``labels``.

    The two arguments are compared element-wise with ``==`` and the mean of
    the resulting mask is returned, so the comparison result must expose a
    ``.mean()`` method (as NumPy arrays do).
    """
    matches = preds == labels
    return matches.mean()


@dataclass
class ModelArguments:
    """
    Arguments pertaining to which model/config/tokenizer we are going to fine-tune from.
    """

    # Required: no default is given, so the argument parser demands a value.
    model_name_or_path: str = field(
        metadata={"help": "Path to pretrained model or model identifier from huggingface.co/models"},
    )
    # Number of cross-attention layers to build. The default is ``None``
    # (hence ``Optional``); the model iterates ``range(xlayers)``, so a value
    # must be supplied before the model can be constructed.
    xlayers: Optional[int] = field(
        default=None,
        metadata={"help": "Number of Cross attention layers"},
    )
    # When unset, ``model_name_or_path`` is used to locate the config.
    config_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained config name or path if not the same as model_name"},
    )
    # When unset, ``model_name_or_path`` is used to locate the tokenizer.
    tokenizer_name: Optional[str] = field(
        default=None,
        metadata={"help": "Pretrained tokenizer name or path if not the same as model_name"},
    )
    cache_dir: Optional[str] = field(
        default=None,
        metadata={"help": "Where do you want to store the pretrained models downloaded from huggingface.co"},
    )
    


@dataclass
class DataTrainingArguments:
    """
    Settings describing the data that is fed to the model for training and evaluation.
    """

    # Required: key into ``processors``; the help text lists the known keys.
    task_name: str = field(
        metadata={"help": "The name of the task to train on: " + ", ".join(processors.keys())},
    )
    # Required: directory holding the task's data files.
    data_dir: str = field(
        metadata={"help": "Should contain the data files for the task."},
    )
    max_seq_length: int = field(
        default=128,
        metadata={
            "help": (
                "The maximum total input sequence length after tokenization. Sequences longer "
                "than this will be truncated, sequences shorter will be padded."
            )
        },
    )
    overwrite_cache: bool = field(
        default=False,
        metadata={"help": "Overwrite the cached training and evaluation sets"},
    )
    

class XATTNBERTForMultipleChoice(nn.Module):
    """Multiple-choice scorer built around a pretrained text encoder.

    The module registers the following sub-networks:

    * ``txtmodel``   -- the ``AutoModel`` loaded from ``args.model_name_or_path``.
    * ``vismodel``   -- the CLIP ``ViT-B/32`` model returned by ``load_clip``.
    * ``xmodel``     -- ``args.xlayers`` ``LXRTXLayer`` cross-attention layers.
    * ``pooler``, ``dropout``, ``classifier`` -- the task head; ``classifier``
      maps a hidden vector to a single score per choice.

    NOTE: ``forward`` currently scores each choice from the first-token hidden
    state of ``txtmodel`` only. The CLIP / cross-attention branch, the pooler
    and the dropout are constructed and registered but not applied; the
    disabled branch is kept as a commented block inside ``forward``.
    """

    def __init__(self, args):
        """Build all sub-networks described by ``args``.

        Args:
            args: Object exposing ``model_name_or_path``, ``config_name``,
                ``tokenizer_name``, ``cache_dir`` and ``xlayers``.

        Raises:
            ValueError: If ``args.xlayers`` is ``None`` (the cross-attention
                stack is sized with ``range(args.xlayers)``).
        """
        # Fail early with an actionable message instead of the opaque
        # ``TypeError`` that ``range(None)`` would raise after the (slow)
        # model downloads below.
        if args.xlayers is None:
            raise ValueError(
                "`xlayers` must be set to the number of cross-attention layers (got None)."
            )

        super().__init__()

        config = AutoConfig.from_pretrained(
            args.config_name if args.config_name else args.model_name_or_path,
            cache_dir=args.cache_dir,
        )

        self.tokenizer = AutoTokenizer.from_pretrained(
            args.tokenizer_name if args.tokenizer_name else args.model_name_or_path,
            cache_dir=args.cache_dir,
        )

        # TEXT model
        self.txtmodel = AutoModel.from_pretrained(
            args.model_name_or_path,
            from_tf=bool(".ckpt" in args.model_name_or_path),
            config=config,
            cache_dir=args.cache_dir,
        )

        # IMAGE model
        # Use the GPU when one is present, otherwise fall back to the CPU so
        # the model can still be built on CPU-only hosts.
        # NOTE(review): assumes `load_clip` accepts a "cpu" device string the
        # same way it accepts "cuda" -- confirm against models.clip.
        clip_device = "cuda" if torch.cuda.is_available() else "cpu"
        self.vismodel, _ = load_clip('ViT-B/32', clip_device, jit=False)
        vis_config = copy.deepcopy(config)
        vis_config.hidden_size = 512  # clip hidden size
        vis_config.num_attention_heads = 8  # clip number of heads

        # CROSS model
        self.xmodel = nn.ModuleList(
            [LXRTXLayer(config, vis_config) for _ in range(args.xlayers)]
        )

        # POOLER (registered, but `forward` currently takes the first token directly)
        self.pooler = BertPooler(config)

        # Fine-tuning task head: one score per (flattened) choice.
        self.dropout = nn.Dropout(config.hidden_dropout_prob)
        self.classifier = nn.Linear(config.hidden_size, 1)
        self.classifier.weight.data.normal_(mean=0.0, std=config.initializer_range)

    def forward(
        self,
        input_ids=None,
        vis_input_ids=None,
        token_type_ids=None,
        attention_mask=None,
        visn_input_mask = None,
        sent = None,
        labels=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,

    ):
        """Score every choice and optionally compute the cross-entropy loss.

        Args:
            input_ids: Token ids; flattened to ``(-1, input_ids.shape[-1])``
                before being fed to ``txtmodel``.
            vis_input_ids: When given, its second dimension is taken as the
                number of choices. When ``None``, ``input_ids.shape[1]`` is
                used instead.
            token_type_ids: Optional; flattened like ``input_ids``.
            attention_mask: Optional; flattened like ``input_ids``.
            labels: Optional targets for ``nn.CrossEntropyLoss`` over the
                ``(-1, num_choices)`` logits. When ``None`` no loss is computed.
            visn_input_mask, sent, position_ids, head_mask, inputs_embeds,
            output_attentions, output_hidden_states, return_dict: Accepted
                for interface compatibility; not used by the active code path.

        Returns:
            MultipleChoiceModelOutput: ``logits`` reshaped to
            ``(-1, num_choices)`` and ``loss`` (``None`` when ``labels`` is
            ``None``); ``hidden_states`` and ``attentions`` are always ``None``.
        """
        # The number of choices is the only thing the active path needs from
        # `vis_input_ids`; fall back to `input_ids` so text-only batches work.
        if vis_input_ids is not None:
            num_choices = vis_input_ids.shape[1]
        else:
            num_choices = input_ids.shape[1]

        # Collapse all leading dimensions so each row is one token sequence.
        seq_len = input_ids.shape[-1]
        flat_input_ids = input_ids.reshape(-1, seq_len)
        # Both of these are optional for the text encoder; only reshape what
        # was actually provided.
        flat_token_type_ids = (
            token_type_ids.reshape(-1, seq_len) if token_type_ids is not None else None
        )
        flat_attention_mask = (
            attention_mask.reshape(-1, seq_len) if attention_mask is not None else None
        )

        lang_feats = self.txtmodel(
            flat_input_ids,
            token_type_ids=flat_token_type_ids,
            attention_mask=flat_attention_mask,
        ).last_hidden_state

        # --- Disabled cross-modal branch (kept for reference) ----------------
        # batch_size, num_choices, nseg, seq_len = vis_input_ids.shape
        # vis_input_ids = vis_input_ids.reshape(-1, vis_input_ids.shape[-1])
        # visn_feats, visn_input_mask, _, _ = self.vismodel.encode_text(vis_input_ids)
        # visn_input_mask = visn_input_mask.reshape(batch_size*num_choices, nseg*77)
        # visn_feats = visn_feats.reshape(batch_size*num_choices, nseg*77, -1)
        #
        # extended_attention_mask = attention_mask.unsqueeze(1).unsqueeze(2)
        # # Since attention_mask is 1.0 for positions we want to attend and 0.0 for
        # # masked positions, this operation will create a tensor which is 0.0 for
        # # positions we want to attend and -10000.0 for masked positions.
        # # Since we are adding it to the raw scores before the softmax, this is
        # # effectively the same as removing these entirely.
        # extended_attention_mask = extended_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        # extended_attention_mask = (1.0 - extended_attention_mask) * -10000.0
        #
        # if visn_input_mask is not None:
        #     extended_visual_attention_mask = visn_input_mask.unsqueeze(1).unsqueeze(2)
        #     extended_visual_attention_mask = extended_visual_attention_mask.to(dtype=next(self.parameters()).dtype)  # fp16 compatibility
        #     extended_visual_attention_mask = (1.0 - extended_visual_attention_mask) * -10000.0
        # else:
        #     extended_visual_attention_mask = None
        # for layer_module in self.xmodel:
        #     lang_feats, visn_feats = layer_module(lang_feats, extended_attention_mask,
        #                                           visn_feats, extended_visual_attention_mask)
        # ---------------------------------------------------------------------

        # First-token representation; `self.pooler` and `self.dropout` are
        # intentionally not applied here (they were disabled in the original).
        pooled_output = lang_feats[:, 0]

        logits = self.classifier(pooled_output)
        reshaped_logits = logits.view(-1, num_choices)

        # Only compute a loss when targets are supplied, so the model can be
        # called for pure inference.
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(reshaped_logits, labels)

        return MultipleChoiceModelOutput(
            loss=loss,
            logits=reshaped_logits,
            hidden_states=None,
            attentions=None,
        )

def _build_dataset(data_args, tokenizer, split):
    """Create the ``MultipleChoiceDataset`` for ``split`` with visual inputs enabled."""
    return MultipleChoiceDataset(
        data_dir=data_args.data_dir,
        tokenizer=tokenizer,
        task=data_args.task_name,
        max_seq_length=data_args.max_seq_length,
        overwrite_cache=data_args.overwrite_cache,
        mode=split,
        vis=True,
    )


def _load_extra_state_dict(model, checkpoint_path):
    """Load a saved state dict from ``checkpoint_path`` into ``model`` (non-strict).

    Nothing is loaded (a warning is logged) when ``checkpoint_path`` is not a
    regular file, e.g. when it is a hub identifier or a directory.
    """
    if not os.path.isfile(checkpoint_path):
        logger.warning(
            "%s is not a checkpoint file; skipping state-dict loading "
            "(only the weights loaded by the model constructor are used).",
            checkpoint_path,
        )
        return

    # SECURITY: torch.load unpickles the file -- only use trusted checkpoints.
    # map_location="cpu" lets checkpoints saved on a GPU load on any host;
    # load_state_dict copies the tensors into the model's own parameters.
    state_dict = torch.load(checkpoint_path, map_location="cpu")
    load_result = model.load_state_dict(state_dict, strict=False)

    # strict=False hides mismatches, so surface them in the log.
    missing = getattr(load_result, "missing_keys", None)
    unexpected = getattr(load_result, "unexpected_keys", None)
    if missing:
        logger.info("Keys missing from checkpoint: %s", missing)
    if unexpected:
        logger.info("Unexpected keys in checkpoint: %s", unexpected)


def main():
    """Parse CLI arguments, then train / evaluate / predict as requested.

    Returns:
        dict: The metrics returned by ``trainer.evaluate()`` when ``--do_eval``
        is set, otherwise an empty dict.

    Raises:
        ValueError: If the output directory is non-empty while training
            without ``--overwrite_output_dir``, or if the task name is unknown.
    """
    # See all possible arguments in src/transformers/training_args.py
    # or by passing the --help flag to this script.
    # We now keep distinct sets of args, for a cleaner separation of concerns.
    parser = HfArgumentParser((ModelArguments, DataTrainingArguments, TrainingArguments))
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()

    if (
        os.path.exists(training_args.output_dir)
        and os.listdir(training_args.output_dir)
        and training_args.do_train
        and not training_args.overwrite_output_dir
    ):
        raise ValueError(
            f"Output directory ({training_args.output_dir}) already exists and is not empty. Use --overwrite_output_dir to overcome."
        )

    # Setup logging
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s -   %(message)s",
        datefmt="%m/%d/%Y %H:%M:%S",
        level=logging.INFO if training_args.local_rank in [-1, 0] else logging.WARN,
    )
    logger.warning(
        "Process rank: %s, device: %s, n_gpu: %s, distributed training: %s, 16-bits training: %s",
        training_args.local_rank,
        training_args.device,
        training_args.n_gpu,
        bool(training_args.local_rank != -1),
        training_args.fp16,
    )
    # Set the verbosity to info of the Transformers logger (on main process only):
    if is_main_process(training_args.local_rank):
        transformers.utils.logging.set_verbosity_info()
        transformers.utils.logging.enable_default_handler()
        transformers.utils.logging.enable_explicit_format()
    logger.info("Training/evaluation parameters %s", training_args)

    # Set seed
    set_seed(training_args.seed)

    # Only the registry lookup is guarded, so a KeyError raised inside the
    # processor itself is not misreported as an unknown task.
    try:
        processor_cls = processors[data_args.task_name]
    except KeyError:
        raise ValueError("Task not found: %s" % (data_args.task_name)) from None
    label_list = processor_cls().get_labels()
    num_labels = len(label_list)

    # Load pretrained model and tokenizer
    #
    # Distributed training:
    # The .from_pretrained methods guarantee that only one local process can concurrently
    # download model & vocab.
    tokenizer = AutoTokenizer.from_pretrained(
        model_args.tokenizer_name if model_args.tokenizer_name else model_args.model_name_or_path,
        cache_dir=model_args.cache_dir,
    )
    # The model builds its own config from `model_args`.
    model = XATTNBERTForMultipleChoice(model_args)
    _load_extra_state_dict(model, model_args.model_name_or_path)

    # Get datasets
    train_dataset = (
        _build_dataset(data_args, tokenizer, Split.train) if training_args.do_train else None
    )
    # The dev split is also what --do_predict runs on, so build it for either flag.
    eval_dataset = (
        _build_dataset(data_args, tokenizer, Split.dev)
        if (training_args.do_eval or training_args.do_predict)
        else None
    )

    def compute_metrics(p: EvalPrediction) -> Dict:
        preds = np.argmax(p.predictions, axis=1)
        return {"acc": simple_accuracy(preds, p.label_ids)}

    # Initialize our Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        compute_metrics=compute_metrics,
    )

    # Training
    if training_args.do_train:
        trainer.train(
            model_path=model_args.model_name_or_path if os.path.isdir(model_args.model_name_or_path) else None
        )
        trainer.save_model()
        # For convenience, we also re-save the tokenizer to the same directory,
        # so that you can share your model easily on huggingface.co/models =)
        # Only one process writes, so distributed ranks do not race on the files.
        if trainer.is_world_process_zero():
            tokenizer.save_pretrained(training_args.output_dir)

    # Evaluation
    results = {}
    if training_args.do_eval:
        logger.info("*** Evaluate ***")

        result = trainer.evaluate()

        output_eval_file = os.path.join(training_args.output_dir, "eval_results.txt")
        if trainer.is_world_process_zero():
            with open(output_eval_file, "w") as writer:
                logger.info("***** Eval results *****")
                for key, value in result.items():
                    logger.info("  %s = %s", key, value)
                    writer.write("%s = %s\n" % (key, value))

        # Every rank still returns the metrics, as before.
        results.update(result)

    if training_args.do_predict:
        logger.info("*** Test ***")

        # NOTE: predictions are produced for the dev split; no separate test
        # split is loaded by this script.
        preds_confidence = trainer.predict(test_dataset=eval_dataset).predictions
        # Consider one score column per label (previously hard-coded to two).
        scores = preds_confidence[:, :num_labels]
        predictions = np.argmax(scores, axis=1)

        output_test_file = os.path.join(training_args.output_dir, "test_results.txt")
        if trainer.is_world_process_zero():
            with open(output_test_file, "w") as writer:
                logger.info("***** Test results *****")
                header = ["index", "prediction"]
                header.extend(f"confidence{col + 1}" for col in range(scores.shape[1]))
                writer.write("\t".join(header) + "\n")
                for index, (choice, row) in enumerate(zip(predictions, scores)):
                    row_text = "\t".join(str(value) for value in row)
                    writer.write(f"{index}\t{label_list[choice]}\t{row_text}\n")

    return results


# def _mp_fn(index):
#     # For xla_spawn (TPUs)
#     main()


# Script entry point: run the pipeline only when this file is executed
# directly, not when it is imported as a module.
if __name__ == "__main__":
    main()
